from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.token_embedders import Embedding, PretrainedTransformerEmbedder, ElmoTokenEmbedder, \
    TokenCharactersEncoder
from allennlp.modules.seq2vec_encoders import CnnEncoder
from allennlp.common import Params


class EmbedderFactory:
    """Select and construct token embedders by embedding-scheme name.

    Attributes:
        model_name: Name of the embedding scheme to build. Compared against
            the fixed names handled in ``get_embedder``; any other name that
            contains ``'bert'`` is handed to ``PretrainedTransformerEmbedder``.
        d_embedding: Path prefix under which the pretrained GloVe, word2vec,
            fastText and ELMo files are looked up.
    """

    def __init__(self, model_name, d_embedding):
        self.model_name = model_name
        self.d_embedding = d_embedding

    @staticmethod
    def _new_token_embedding(vocab):
        """Return a 300-dimensional ``Embedding`` sized to the default vocab namespace."""
        return Embedding(num_embeddings=vocab.get_vocab_size(), embedding_dim=300)

    @staticmethod
    def _pretrained_token_embedding(vocab, pretrained_file):
        """Return a 300-dimensional ``Embedding`` configured with ``pretrained_file``."""
        config = Params({'pretrained_file': pretrained_file, 'embedding_dim': 300})
        return Embedding.from_params(vocab=vocab, params=config)

    def get_embedder(self, vocab: Vocabulary):
        """Build the embedder mapping for ``self.model_name``.

        Supported names and what is returned for each:

        * ``'word'``: ``'tokens'`` -> a new 300-d ``Embedding``.
        * ``'glove'`` / ``'word2vec'`` / ``'fasttext'``: ``'tokens'`` -> a
          300-d ``Embedding`` created via ``Embedding.from_params`` with the
          matching pretrained file under ``self.d_embedding``.
        * ``'char'``: ``'token_characters'`` -> a ``TokenCharactersEncoder``
          combining a 70-d character ``Embedding`` with a ``CnnEncoder``
          (300 filters, n-gram size 6).
        * ``'elmo'``: ``'tokens'`` -> a new 300-d ``Embedding`` and
          ``'token_characters'`` -> an ``ElmoTokenEmbedder`` loaded from the
          options/weights files under ``self.d_embedding``.
        * any other name containing ``'bert'``: ``'tokens'`` -> a
          ``PretrainedTransformerEmbedder`` with ``train_parameters=True``.

        Args:
            vocab: Vocabulary used to size the embeddings.

        Returns:
            A dict from key name to embedder module. The keys presumably have
            to match the token-indexer names used by the caller -- TODO confirm.

        Raises:
            ValueError: If ``self.model_name`` matches none of the cases above.
        """
        name = self.model_name

        if name == 'word':
            return {'tokens': self._new_token_embedding(vocab)}

        if name == 'glove':
            glove_file = f'{self.d_embedding}/glove/glove.6B.300d.txt'
            return {'tokens': self._pretrained_token_embedding(vocab, glove_file)}

        if name == 'word2vec':
            word2vec_file = f'{self.d_embedding}/word2vec/GoogleNews-vectors-negative300.txt'
            return {'tokens': self._pretrained_token_embedding(vocab, word2vec_file)}

        if name == 'fasttext':
            fasttext_file = f'{self.d_embedding}/fasttext/wiki-news-300d-1M.vec'
            return {'tokens': self._pretrained_token_embedding(vocab, fasttext_file)}

        if name == 'char':
            # Characters are embedded first, then pooled per token by the CNN.
            char_vocab_size = vocab.get_vocab_size('token_characters')
            char_embedding = Embedding(num_embeddings=char_vocab_size, embedding_dim=70)
            char_cnn = CnnEncoder(embedding_dim=70, num_filters=300, ngram_filter_sizes=(6,))
            return {'token_characters': TokenCharactersEncoder(char_embedding, char_cnn)}

        if name == 'elmo':
            # The plain token embedding is built before ELMo is loaded,
            # matching the key order of the returned mapping.
            token_embedding = self._new_token_embedding(vocab)
            elmo_embedding = ElmoTokenEmbedder(
                options_file=f'{self.d_embedding}/elmo/elmo_2x1024_128_2048cnn_1xhighway_options.json',
                weight_file=f'{self.d_embedding}/elmo/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5')
            return {'tokens': token_embedding, 'token_characters': elmo_embedding}

        # Substring match: checked last so the exact names above take precedence.
        if 'bert' in name:
            transformer = PretrainedTransformerEmbedder(model_name=name,
                                                        train_parameters=True)
            return {'tokens': transformer}

        raise ValueError(f'{self.model_name} embedder not supported!')
